def short_reg_def(var, half):
    """Return a PTX ``.reg.u16`` declaration for eight half-registers of *var*.

    Registers are named ``<var><index><half>`` for index 0..7,
    e.g. ``a0l, a1l, ..., a7l``.
    """
    prefix = ".reg.u16 "
    names = ("%s%s%s" % (var, idx, half) for idx in range(8))
    return prefix + ", ".join(names)

def short_regs_def():
    """Return the four ``.reg.u16`` declarations for the halves of a and b.

    Order: a-low, a-high, b-low, b-high.
    """
    return [short_reg_def(var, half) for var in ("a", "b") for half in ("l", "h")]

def word_regs_def():
    """Return a one-element list declaring the 32-bit scratch registers t0..t7."""
    temps = ", ".join("t%s" % k for k in range(8))
    return [".reg.u32 " + temps]

def load_half(var, i):
    """Return a template that unpacks word ``var[i]`` into two 16-bit registers.

    The result is itself a ``str.format`` template. The doubled braces survive
    as the literal PTX vector braces, and ``{var[i]}`` is the operand
    placeholder that is substituted in a later formatting step.
    """
    lo_reg = "{}{}l".format(var, i)
    hi_reg = "{}{}h".format(var, i)
    operand = "{" + "{}[{}]".format(var, i) + "}"
    return "mov.b32 {{ " + lo_reg + ", " + hi_reg + " }}, " + operand

def load_halves(var):
    """Return the eight unpack templates for words ``var[0]`` .. ``var[7]``."""
    return list(map(load_half, [var] * 8, range(8)))

def mul_first_pass_row_zero():
    """Return templates writing ``a[j].hi * b[0].lo`` directly into ``z[j]``, j = 0..7.

    Each 16x16 -> 32-bit product is written to its own word with
    ``mul.wide.u16``, so no addition or carry handling is needed here.
    """
    return ["mul.wide.u16 {z[%s]}, a%sh, b0l" % (col, col) for col in range(8)]

def mul_row(i, ahalf, bhalf):
    """Return templates adding ``a[j].<ahalf> * b[i].<bhalf>`` into ``z[i + j]``, j = 0..7.

    All eight products are first computed into t0..t7. They are then
    accumulated in one carry chain: ``add.cc`` for ``z[i]`` and ``addc.cc``
    for ``z[i+1]`` .. ``z[i+7]``. The chain starts with ``add.cc``, so no
    incoming carry is consumed. It ends with ``addc.cc``, so the carry out of
    ``z[i + 7]`` is left in the condition code for whatever instruction
    follows.
    """
    products = [
        "mul.wide.u16 t%s, a%s%s, b%s%s" % (j, j, ahalf, i, bhalf)
        for j in range(8)
    ]
    accumulate = []
    for j in range(8):
        opcode = "add.cc.u32" if j == 0 else "addc.cc.u32"
        word = i if j == 0 else i + j
        accumulate.append("%s {z[%s]}, {z[%s]}, t%s" % (opcode, word, word, j))
    return products + accumulate

def mul_first_pass_row_double_zero():
    """Return templates for row 0 of the cross-product pass.

    Writes ``a[j].hi * b[0].lo`` into z[0..7] and adds ``a[j].lo * b[0].hi``
    on top. The final carry (0 or 1) is stored in z[8].
    """
    seed = mul_first_pass_row_zero()
    cross = mul_row(0, "l", "h")
    return seed + cross + ["addc.u32 {z[8]}, 0, 0"]

def mul_first_pass_row_double(i):
    """Return templates adding both cross products of row *i* into z[i .. i+7].

    First adds ``a[j].hi * b[i].lo``, then ``a[j].lo * b[i].hi``. Only the
    second chain's carry is captured, into z[i + 8].
    """
    # The second chain starts with add.cc, so the first chain's carry is not
    # consumed. NOTE(review): this relies on z[i + 7] holding only the 0/1
    # carry stored by row i - 1 (rows emitted in order after row 0). Then
    # 1 + (2**16 - 1)**2 + 1 < 2**32, so the first chain cannot overflow.
    tail = "addc.u32 {z[%s]}, 0, 0" % (i + 8)
    return mul_row(i, "h", "l") + mul_row(i, "l", "h") + [tail]

def mul_second_pass_row(i, half):
    """Return templates adding ``a[j].<half> * b[i].<half>`` (j = 0..7) into z.

    A low*low product for (i, j) lands at word i + j. A high*high product
    carries an extra factor of 2**32, so it lands at word i + j + 1. After
    the eight products are added, the carry is rippled with ``addc.cc``
    through every remaining word up to z[15]. The chain ends with a
    non-``.cc`` ``addc``.
    """
    base = i + (1 if half == "h" else 0)
    # With i in 0..7 this is only reached for i == 7, half == "h". The last
    # product then targets z[15] itself, so it is folded into the final addc.
    spills_into_top = base >= 8
    chained_terms = 7 if spills_into_top else 8

    out = ["mul.wide.u16 t%s, a%s%s, b%s%s" % (j, j, half, i, half) for j in range(8)]
    out.append("add.cc.u32 {z[%s]}, {z[%s]}, t0" % (base, base))
    out.extend(
        "addc.cc.u32 {z[%s]}, {z[%s]}, t%s" % (base + j, base + j, j)
        for j in range(1, chained_terms)
    )
    out.extend(
        "addc.cc.u32 {z[%s]}, {z[%s]}, 0" % (word, word)
        for word in range(base + 8, 15)
    )
    out.append("addc.u32 {z[15]}, {z[15]}, %s" % ("t7" if spills_into_top else "0"))
    return out

def shift_16():
    """Return templates shifting the 16-word value z[0..15] left by 16 bits.

    Words are processed from most to least significant. Each funnel shift
    (``shf.l``) therefore still reads the unmodified lower neighbour. Bits
    shifted out of z[15] are dropped. The lowest word is finished with a
    plain ``shl``.
    """
    steps = []
    for hi in reversed(range(1, 16)):
        lo = hi - 1
        steps.append("shf.l.clamp.b32 {z[%s]}, {z[%s]}, {z[%s]}, 16" % (hi, lo, hi))
    steps.append("shl.b32 {z[0]}, {z[0]}, 16")
    return steps

def mul_first_pass():
    """Return templates accumulating all mixed-half products into z.

    Covers every ``a[j].hi * b[i].lo`` and ``a[j].lo * b[i].hi``: row 0
    first, then rows 1..7 in order.
    """
    rows = [mul_first_pass_row_double_zero()]
    rows.extend(mul_first_pass_row_double(k) for k in range(1, 8))
    return [line for row in rows for line in row]

def preamble():
    """Return register declarations followed by the unpacking of a and b into halves."""
    parts = (word_regs_def(), short_regs_def(), load_halves("a"), load_halves("b"))
    return [line for part in parts for line in part]

def mul():
    """Return every asm line template for the 8x8-word -> 16-word product z = a * b.

    Pass 1 (``mul_first_pass``) accumulates the mixed hi*lo / lo*hi products,
    and ``shift_16`` moves that sum up by 16 bits. Pass 2 adds the lo*lo
    products (at word i + j) and the hi*hi products (at word i + j + 1).
    """
    # The preamble reads every a/b input operand before any z word is written.
    # Reuse mul_first_pass() rather than duplicating its row loop here.
    res = preamble() + mul_first_pass() + shift_16()

    for i in range(8):
        res.extend(mul_second_pass_row(i, "l"))
        res.extend(mul_second_pass_row(i, "h"))

    return res

def output():
    """Print to stdout a CUDA C++ header defining ``mul_wide`` with the generated PTX.

    Each template from ``mul()`` is formatted with the inline-asm operand
    numbers: z -> %0..%15, a -> %16..%23, b -> %24..%31.
    """
    templates = mul()
    operands = {
        "z": ["%" + str(n) for n in range(16)],
        "a": ["%" + str(n) for n in range(16, 24)],
        "b": ["%" + str(n) for n in range(24, 32)],
    }

    def constraints(kind, arr, start, stop):
        # e.g. "=r"(z[0]), "=r"(z[1]), ...
        return ", ".join("\"{}\"({}[{}])".format(kind, arr, n) for n in range(start, stop))

    text = [
        "#pragma once",
        "",
        "#include <cstdint>",
        "",
        "__device__ __forceinline__ void mul_wide(uint32_t z[16], const uint32_t a[8], const uint32_t b[8])",
        "{",
        "    asm(R\"({",
    ]
    text.extend(" " * 8 + tpl.format(**operands) + ";" for tpl in templates)
    text += [
        "   })\"",
        "   : " + constraints("=r", "z", 0, 8) + ",",
        "     " + constraints("=r", "z", 8, 16),
        "   : " + constraints("r", "a", 0, 8) + ",",
        "     " + constraints("r", "b", 0, 8),
        "    );",
        "}",
    ]
    print("\n".join(text))

if __name__ == '__main__':
    # Script entry point: print the generated header to stdout.
    output()